from copy import deepcopy

import torch
import torch.nn.functional as F


def get_loss(loss_type: str = 'mse'):
    """Return the functional loss matching *loss_type*.

    Args:
        loss_type: ``'mse'`` or ``'l2'`` selects ``F.mse_loss``; ``'l1'``
            selects ``F.l1_loss``.

    Returns:
        The selected ``torch.nn.functional`` loss function.

    Raises:
        ValueError: if *loss_type* is not one of the supported names.
    """
    if loss_type in ('mse', 'l2'):
        return F.mse_loss
    if loss_type == 'l1':
        return F.l1_loss
    # Falling through would return None, which only fails later when the
    # caller tries to call it; reject unknown names up front instead.
    raise ValueError(
        f"Unsupported loss_type {loss_type!r}; expected 'mse', 'l2' or 'l1'")

def get_single_head(head_kwargs, task_name):
    """Build the prediction head for one task.

    ``in_f``, ``hidden`` and ``out_f`` are read from *head_kwargs* and passed
    positionally to ``HeadNetwork``; ``dropout`` and *task_name* are passed
    by keyword.
    """
    from .network import HeadNetwork

    in_features = head_kwargs['in_f']
    hidden_size = head_kwargs['hidden']
    out_features = head_kwargs['out_f']
    return HeadNetwork(in_features, hidden_size, out_features,
                       dropout=head_kwargs['dropout'], task_name=task_name)


def get_heads(head_kwargs, task_names):
    """Return a ``ModuleDict`` holding one independently built head per task.

    Every head is created from the same *head_kwargs*; the dict is keyed by
    task name.
    """
    head_by_task = {}
    for name in task_names:
        head_by_task[name] = get_single_head(head_kwargs, name)
    return torch.nn.ModuleDict(head_by_task)


def get_single_bottleneck(bottleneck_kwargs, task_name):
    """Build the bottleneck network for one task.

    ``in_f``, ``hidden`` and ``dropout`` are read from *bottleneck_kwargs*
    and passed positionally (in that order) to ``BottleneckNetwork``;
    *task_name* is passed by keyword.
    """
    from .network import BottleneckNetwork

    positional_args = (
        bottleneck_kwargs['in_f'],
        bottleneck_kwargs['hidden'],
        bottleneck_kwargs['dropout'],
    )
    return BottleneckNetwork(*positional_args, task_name=task_name)


def get_bottlenecks(bottleneck_kwargs, task_names):
    """Return a ``ModuleDict`` with a separate bottleneck for every task name."""
    per_task = {
        name: get_single_bottleneck(bottleneck_kwargs, name)
        for name in task_names
    }
    return torch.nn.ModuleDict(per_task)


def get_single_backbone(backbone_kwargs, node=False):
    """Build one complete ``SingleNetwork`` backbone.

    ``in_f`` and ``out_f`` are passed positionally; ``depth``, ``heads`` and
    ``dropout`` from *backbone_kwargs*, plus the *node* flag, are passed by
    keyword.
    """
    from .network import SingleNetwork

    in_features = backbone_kwargs['in_f']
    out_features = backbone_kwargs['out_f']
    keyword_args = dict(
        depth=backbone_kwargs['depth'],
        heads=backbone_kwargs['heads'],
        dropout=backbone_kwargs['dropout'],
        node=node,
    )
    return SingleNetwork(in_features, out_features, **keyword_args)


def get_backbones(backbone_kwargs, task_names):
    """Return a ``ModuleDict`` mapping each task name to its own backbone.

    Backbones are built with the default ``node`` setting of
    ``get_single_backbone``.
    """
    backbone_by_task = {}
    for task in task_names:
        backbone_by_task[task] = get_single_backbone(backbone_kwargs)
    return torch.nn.ModuleDict(backbone_by_task)


def get_single_backbone_front(backbone_kwargs):
    """Build the front-end part of the backbone (``SingleNetwork_front``).

    Uses ``in_f`` and ``out_f`` from *backbone_kwargs*, and the ``'front'``
    entry of ``backbone_kwargs['depth']`` as the depth.
    """
    from .network import SingleNetwork_front

    in_features = backbone_kwargs['in_f']
    out_features = backbone_kwargs['out_f']
    front_depth = backbone_kwargs['depth']['front']
    return SingleNetwork_front(in_features, out_features, depth=front_depth)
    

def get_molrep_front(molrep_kwargs):
    """Build a ``MolRep_front`` module from ``molrep_kwargs['rep_path']``."""
    from .network import MolRep_front

    representation_path = molrep_kwargs['rep_path']
    return MolRep_front(representation_path)


def get_backbones_front(backbone_kwargs, task_names):
    """Return a ``ModuleDict`` with an independent backbone front-end per task."""
    fronts = {
        task: get_single_backbone_front(backbone_kwargs)
        for task in task_names
    }
    return torch.nn.ModuleDict(fronts)


def get_single_backbone_back(backbone_kwargs):
    """Build the back-end part of the backbone (``SingleNetwork_back``).

    Uses ``in_f`` and ``out_f`` positionally, the ``'back'`` entry of
    ``backbone_kwargs['depth']`` as the depth, and forwards ``heads`` and
    ``dropout`` by keyword.
    """
    from .network import SingleNetwork_back

    in_features = backbone_kwargs['in_f']
    out_features = backbone_kwargs['out_f']
    back_depth = backbone_kwargs['depth']['back']
    return SingleNetwork_back(in_features, out_features,
                              depth=back_depth,
                              heads=backbone_kwargs['heads'],
                              dropout=backbone_kwargs['dropout'])


def get_backbones_back(backbone_kwargs, task_names):
    """Return a ``ModuleDict`` with an independent backbone back-end per task."""
    backs = {}
    for task in task_names:
        backs[task] = get_single_backbone_back(backbone_kwargs)
    return torch.nn.ModuleDict(backs)

def get_single_transform(transform_kwargs):
    """Build a ``transform_network`` from ``latent_size`` and ``net_width``."""
    from .network import transform_network

    latent = transform_kwargs['latent_size']
    width = transform_kwargs['net_width']
    return transform_network(latent_size=latent, net_width=width)


def get_transforms(transform_kwargs, task_names):
    """Return a ``ModuleDict`` holding one transform network per task name."""
    transform_by_task = {}
    for task in task_names:
        transform_by_task[task] = get_single_transform(transform_kwargs)
    return torch.nn.ModuleDict(transform_by_task)

def get_inv_transforms(backbone_kwargs, task_names):
    """Return a ``ModuleDict`` holding one inverse transform network per task.

    Note: despite its name, *backbone_kwargs* is handed to
    ``get_single_transform``, so it must carry ``latent_size`` and
    ``net_width``.
    """
    inverse_by_task = {
        task: get_single_transform(backbone_kwargs)
        for task in task_names
    }
    return torch.nn.ModuleDict(inverse_by_task)


def get_test_model(backbone_kwargs, bottleneck_kwargs, head_kwargs, setup, molrep_kwargs='', transform_kwargs=None, task_names=None, node=False, d_num=None, perb_ratio=None):
    """Assemble the ``GATEModel`` used for inference.

    Builds one shared backbone front-end and, for every task in
    *task_names*, a backbone back-end, a bottleneck, a head, a forward
    transform and an inverse transform, then wires them into ``GATEModel``.

    Args:
        backbone_kwargs: backbone config; ``depth`` is indexed with
            ``'front'`` and ``'back'``.
        bottleneck_kwargs: config for each per-task bottleneck.
        head_kwargs: config for each per-task head.
        setup: unused here; kept for signature compatibility.
        molrep_kwargs: unused here; kept for signature compatibility.
        transform_kwargs: config for both forward and inverse transforms
            (``latent_size``, ``net_width``).
        task_names: iterable of task names. Any iterable is accepted,
            including one-shot iterators.
        node: unused here; kept for signature compatibility.
        d_num: forwarded to ``GATEModel``.
        perb_ratio: forwarded to ``GATEModel``.

    Returns:
        The assembled ``GATEModel``.

    Raises:
        TypeError: if *task_names* is ``None``, or if tasks are given but
            *transform_kwargs* is ``None``.
    """
    if task_names is None:
        # Fail fast with a clear message instead of a "'NoneType' is not
        # iterable" error raised after the front-end has already been built.
        raise TypeError(
            "get_test_model() requires task_names (an iterable of task names), got None")
    # Materialize once: the names are iterated by five separate builders
    # below, and a generator would be exhausted by the first of them.
    task_names = list(task_names)
    if task_names and transform_kwargs is None:
        raise TypeError(
            "get_test_model() requires transform_kwargs when task_names is non-empty")

    # Defensive copies so the built modules do not share the caller's
    # config objects.
    backbone_kwargs = deepcopy(backbone_kwargs)
    bottleneck_kwargs = deepcopy(bottleneck_kwargs)
    transform_kwargs = deepcopy(transform_kwargs)
    head_kwargs = deepcopy(head_kwargs)

    backbone_front = get_single_backbone_front(backbone_kwargs)
    backbone_back = get_backbones_back(backbone_kwargs, task_names)
    bottlenecks = get_bottlenecks(bottleneck_kwargs, task_names)
    heads = get_heads(head_kwargs, task_names)
    transform = get_transforms(transform_kwargs, task_names)
    inv = get_inv_transforms(transform_kwargs, task_names)
    from .test_model import GATEModel
    model = GATEModel(backbone_front, backbone_back, bottlenecks, transform, inv, heads, d_num, perb_ratio)

    return model
